{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 2: AutoWoE (WhiteBox model for binary classification on tabular data)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../imgs/lightautoml_logo_color.png\" alt=\"LightAutoML logo\" style=\"width:100%;\"/>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Official LightAutoML github repository is [here](https://github.com/AILab-MLTools/LightAutoML)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scorecard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../imgs/tutorial_whitebox_report_1.png\" alt=\"Tutorial whitebox report 1\" style=\"width:100%;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../imgs/tutorial_whitebox_report_2.png\" alt=\"Tutorial whitebox report 2\" style=\"width:100%;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discretization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../imgs/tutorial_whitebox_report_3.png\" alt=\"Tutorial whitebox report 3\" style=\"width:100%;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selection and One-dimensional analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../imgs/tutorial_whitebox_report_4.png\" alt=\"Tutorial whitebox report 4\" style=\"width:100%;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Whitebox pipeline:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "0. Technical\n",
    "\n",
    "    - n_jobs\n",
    "    - debug\n",
    "\n",
    "\n",
    "1. Simple features typing and initial cleaning\n",
    "\n",
    "    1.1. Remove trash features\n",
    "\n",
    "        Medium:\n",
    "            - th_nan \n",
    "            - th_const \n",
    "\n",
    "    1.2. Typling (auto or user defined)\n",
    "        \n",
    "        Critical:\n",
    "            - features_type (dict) {'age': 'real', 'education': 'cat', 'birth_date': (None, (\"d\", \"wd\"), ...}\n",
    "\n",
    "    1.3. Categories and datetimes encoding\n",
    "\n",
    "        Critical:\n",
    "            - features_type (for datetimes)\n",
    "\n",
    "        Optional:\n",
    "            - cat_alpha (int) - greater means more conservative encoding\n",
    "\n",
    "\n",
    "2. Pre selection (based on BlackBox model importances)\n",
    "\n",
    "    - Critical:\n",
    "        - select_type (None or int)\n",
    "        - imp_type (if type(select_type) is int 'perm_imt'/'feature_imp')\n",
    "\n",
    "    - Optional:\n",
    "        - imt_th (float) - threshold for select_type is None\n",
    "\n",
    "\n",
    "3. Binning (discretization)\n",
    "\n",
    "    - Critical:\n",
    "        - monotonic / features_monotone_constraints\n",
    "        - max_bin_count / max_bin_count\n",
    "        - min_bin_size\n",
    "        - cat_merge_to\n",
    "        - nan_merge_to\n",
    "\n",
    "    - Medium:\n",
    "        - force_single_split\n",
    "\n",
    "    - Optional:\n",
    "        - min_bin_mults\n",
    "        - min_gains_to_split\n",
    "\n",
    "\n",
    "4. WoE estimation WoE = LN( ((% 0 in bin) / (% 0 in sample)) / ((% 1 in bin) / (% 1 in sample)) ):\n",
    "\n",
    "    - Critical:\n",
    "        - oof_woe\n",
    "\n",
    "    - Optional:\n",
    "        - woe_diff_th\n",
    "        - n_folds (if oof_woe)\n",
    "\n",
    "\n",
    "5. 2nd selection stage:\n",
    "\n",
    "    5.1. One-dimentional importance\n",
    "\n",
    "        Critical:\n",
    "            - auc_th\n",
    "\n",
    "    5.2. VIF\n",
    "\n",
    "        Critical:\n",
    "            - vif_th\n",
    "\n",
    "    5.3. Partial correlations\n",
    "\n",
    "        Critical:\n",
    "            - pearson_th\n",
    "\n",
    "\n",
    "6. 3rd selection stage (model based)\n",
    "\n",
    "    - Optional:\n",
    "        - n_folds\n",
    "        - l1_base_step\n",
    "        - l1_exp_step\n",
    "\n",
    "    - Do not touch:\n",
    "        - population_size\n",
    "        - feature_groups_count\n",
    "\n",
    "\n",
    "7. Fitting the final model\n",
    "\n",
    "    - Critical:\n",
    "        - regularized_refit\n",
    "        - p_val (if not regularized_refit)\n",
    "        - validation (if not regularized_refit)\n",
    "\n",
    "    - Optional:\n",
    "        - interpreted_model\n",
    "        - l1_base_step (if regularized_refit)\n",
    "        - l1_exp_step (if regularized_refit)\n",
    "\n",
    "8. Report generation\n",
    "\n",
    "    - report_params"
   ]
  },
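  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The WoE formula from step 4 can be made concrete with a small worked example. This is only an illustrative sketch: the symbols $n_{0,b}$, $n_{1,b}$, $n_0$, $n_1$ and all counts below are assumptions introduced here, not values from this dataset. Let $n_{0,b}$ and $n_{1,b}$ be the numbers of class-0 and class-1 rows in bin $b$, and $n_0$ and $n_1$ the class totals in the sample. The formula can then be written as:\n",
    "\n",
    "$$\\mathrm{WoE}_b = \\ln\\left(\\frac{n_{0,b} / n_0}{n_{1,b} / n_1}\\right)$$\n",
    "\n",
    "For a sample with $n_0 = 9000$ and $n_1 = 3000$, a bin holding $n_{0,b} = 500$ and $n_{1,b} = 300$ rows gives:\n",
    "\n",
    "$$\\mathrm{WoE}_b = \\ln\\left(\\frac{500 / 9000}{300 / 3000}\\right) = \\ln\\left(\\frac{5}{9}\\right) \\approx -0.588$$\n",
    "\n",
    "A negative value means class 1 is over-represented in the bin relative to the whole sample; a positive value means the opposite."
   ]
  },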
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pandas import Series, DataFrame\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import requests\n",
    "import joblib\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from autowoe import AutoWoE, ReportDeco"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reading the data and train/test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = '../data/'\n",
    "DATASET_NAME = 'jobs_train.csv'\n",
    "DATASET_FULLNAME = os.path.join(DATASET_DIR, DATASET_NAME)\n",
    "DATASET_URL = 'https://raw.githubusercontent.com/sb-ai-lab/LightAutoML/master/examples/data/jobs_train.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13 μs, sys: 6 μs, total: 19 μs\n",
      "Wall time: 22.2 μs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "if not os.path.exists(DATASET_FULLNAME):\n",
    "    os.makedirs(DATASET_DIR, exist_ok=True)\n",
    "\n",
    "    dataset = requests.get(DATASET_URL).text\n",
    "    with open(DATASET_FULLNAME, 'w') as output:\n",
    "        output.write(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(DATASET_FULLNAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>enrollee_id</th>\n",
       "      <th>city</th>\n",
       "      <th>city_development_index</th>\n",
       "      <th>gender</th>\n",
       "      <th>relevant_experience</th>\n",
       "      <th>enrolled_university</th>\n",
       "      <th>education_level</th>\n",
       "      <th>major_discipline</th>\n",
       "      <th>experience</th>\n",
       "      <th>company_size</th>\n",
       "      <th>company_type</th>\n",
       "      <th>last_new_job</th>\n",
       "      <th>training_hours</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>8949</td>\n",
       "      <td>city_103</td>\n",
       "      <td>0.920</td>\n",
       "      <td>Male</td>\n",
       "      <td>Has relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>STEM</td>\n",
       "      <td>21.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>36</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>29725</td>\n",
       "      <td>city_40</td>\n",
       "      <td>0.776</td>\n",
       "      <td>Male</td>\n",
       "      <td>No relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>STEM</td>\n",
       "      <td>15.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>Pvt Ltd</td>\n",
       "      <td>5.0</td>\n",
       "      <td>47</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>11561</td>\n",
       "      <td>city_21</td>\n",
       "      <td>0.624</td>\n",
       "      <td>NaN</td>\n",
       "      <td>No relevant experience</td>\n",
       "      <td>Full time course</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>STEM</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.0</td>\n",
       "      <td>83</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>33241</td>\n",
       "      <td>city_115</td>\n",
       "      <td>0.789</td>\n",
       "      <td>NaN</td>\n",
       "      <td>No relevant experience</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>Business Degree</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Pvt Ltd</td>\n",
       "      <td>0.0</td>\n",
       "      <td>52</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>666</td>\n",
       "      <td>city_162</td>\n",
       "      <td>0.767</td>\n",
       "      <td>Male</td>\n",
       "      <td>Has relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Masters</td>\n",
       "      <td>STEM</td>\n",
       "      <td>21.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>Funded Startup</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19153</th>\n",
       "      <td>7386</td>\n",
       "      <td>city_173</td>\n",
       "      <td>0.878</td>\n",
       "      <td>Male</td>\n",
       "      <td>No relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>Humanities</td>\n",
       "      <td>14.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>42</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19154</th>\n",
       "      <td>31398</td>\n",
       "      <td>city_103</td>\n",
       "      <td>0.920</td>\n",
       "      <td>Male</td>\n",
       "      <td>Has relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>STEM</td>\n",
       "      <td>14.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>52</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19155</th>\n",
       "      <td>24576</td>\n",
       "      <td>city_103</td>\n",
       "      <td>0.920</td>\n",
       "      <td>Male</td>\n",
       "      <td>Has relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Graduate</td>\n",
       "      <td>STEM</td>\n",
       "      <td>21.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>Pvt Ltd</td>\n",
       "      <td>4.0</td>\n",
       "      <td>44</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19156</th>\n",
       "      <td>5756</td>\n",
       "      <td>city_65</td>\n",
       "      <td>0.802</td>\n",
       "      <td>Male</td>\n",
       "      <td>Has relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>High School</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.0</td>\n",
       "      <td>999.0</td>\n",
       "      <td>Pvt Ltd</td>\n",
       "      <td>2.0</td>\n",
       "      <td>97</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19157</th>\n",
       "      <td>23834</td>\n",
       "      <td>city_67</td>\n",
       "      <td>0.855</td>\n",
       "      <td>NaN</td>\n",
       "      <td>No relevant experience</td>\n",
       "      <td>no_enrollment</td>\n",
       "      <td>Primary School</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>127</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>19158 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       enrollee_id      city  city_development_index gender  \\\n",
       "0             8949  city_103                   0.920   Male   \n",
       "1            29725   city_40                   0.776   Male   \n",
       "2            11561   city_21                   0.624    NaN   \n",
       "3            33241  city_115                   0.789    NaN   \n",
       "4              666  city_162                   0.767   Male   \n",
       "...            ...       ...                     ...    ...   \n",
       "19153         7386  city_173                   0.878   Male   \n",
       "19154        31398  city_103                   0.920   Male   \n",
       "19155        24576  city_103                   0.920   Male   \n",
       "19156         5756   city_65                   0.802   Male   \n",
       "19157        23834   city_67                   0.855    NaN   \n",
       "\n",
       "           relevant_experience enrolled_university education_level  \\\n",
       "0      Has relevant experience       no_enrollment        Graduate   \n",
       "1       No relevant experience       no_enrollment        Graduate   \n",
       "2       No relevant experience    Full time course        Graduate   \n",
       "3       No relevant experience                 NaN        Graduate   \n",
       "4      Has relevant experience       no_enrollment         Masters   \n",
       "...                        ...                 ...             ...   \n",
       "19153   No relevant experience       no_enrollment        Graduate   \n",
       "19154  Has relevant experience       no_enrollment        Graduate   \n",
       "19155  Has relevant experience       no_enrollment        Graduate   \n",
       "19156  Has relevant experience       no_enrollment     High School   \n",
       "19157   No relevant experience       no_enrollment  Primary School   \n",
       "\n",
       "      major_discipline  experience  company_size    company_type  \\\n",
       "0                 STEM        21.0           NaN             NaN   \n",
       "1                 STEM        15.0          99.0         Pvt Ltd   \n",
       "2                 STEM         5.0           NaN             NaN   \n",
       "3      Business Degree         0.0           NaN         Pvt Ltd   \n",
       "4                 STEM        21.0          99.0  Funded Startup   \n",
       "...                ...         ...           ...             ...   \n",
       "19153       Humanities        14.0           NaN             NaN   \n",
       "19154             STEM        14.0           NaN             NaN   \n",
       "19155             STEM        21.0          99.0         Pvt Ltd   \n",
       "19156              NaN         0.0         999.0         Pvt Ltd   \n",
       "19157              NaN         2.0           NaN             NaN   \n",
       "\n",
       "       last_new_job  training_hours  target  \n",
       "0               1.0              36     1.0  \n",
       "1               5.0              47     0.0  \n",
       "2               0.0              83     0.0  \n",
       "3               0.0              52     1.0  \n",
       "4               4.0               8     0.0  \n",
       "...             ...             ...     ...  \n",
       "19153           1.0              42     1.0  \n",
       "19154           4.0              52     1.0  \n",
       "19155           4.0              44     0.0  \n",
       "19156           2.0              97     0.0  \n",
       "19157           1.0             127     0.0  \n",
       "\n",
       "[19158 rows x 14 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(data.drop('enrollee_id', axis=1), test_size=0.2, stratify=data['target'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AutoWoe: default settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "auto_woe_0 = AutoWoE(interpreted_model=True,\n",
    "                     monotonic=False,\n",
    "                     max_bin_count=5,\n",
    "                     select_type=None,\n",
    "                     pearson_th=0.9,\n",
    "                     metric_th=.505,\n",
    "                     vif_th=10.,\n",
    "                     imp_th=0,\n",
    "                     th_const=32,\n",
    "                     force_single_split=True,\n",
    "                     th_nan=0.01,\n",
    "                     th_cat=0.005,\n",
    "                     metric_tol=1e-4,\n",
    "                     cat_alpha=100,\n",
    "                     cat_merge_to=\"to_woe_0\",\n",
    "                     nan_merge_to=\"to_woe_0\",\n",
    "                     imp_type=\"feature_imp\",\n",
    "                     regularized_refit=False,\n",
    "                     p_val=0.05,\n",
    "                     verbose=2\n",
    "        )\n",
    "\n",
    "auto_woe_0 = ReportDeco(auto_woe_0, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Number of positive: 3033, number of negative: 9227\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000495 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 518\n",
      "[LightGBM] [Info] Number of data points in the train set: 12260, number of used features: 12\n",
      "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.247390 -> initscore=-1.112582\n",
      "[LightGBM] [Info] Start training from score -1.112582\n",
      "Training until validation scores don't improve for 10 rounds\n",
      "Early stopping, best iteration is:\n",
      "[11]\tval_test's auc: 0.810634\n",
      "city processing...\n",
      "city_development_index processing...\n",
      "gender processing...\n",
      "relevant_experience processing...\n",
      "enrolled_university processing...\n",
      "education_level processing...\n",
      "major_discipline processing...\n",
      "experience processing...\n",
      "company_size processing...\n",
      "company_type processing...\n",
      "last_new_job processing...\n",
      "training_hours processing...\n",
      "dict_keys(['city', 'city_development_index', 'gender', 'relevant_experience', 'enrolled_university', 'education_level', 'major_discipline', 'experience', 'company_size', 'company_type', 'last_new_job', 'training_hours']) to selector !!!!!\n",
      "Feature selection...\n",
      "city_development_index   -0.956279\n",
      "company_size             -0.859531\n",
      "company_type             -0.412534\n",
      "experience               -0.289671\n",
      "enrolled_university      -0.259906\n",
      "education_level          -0.603299\n",
      "major_discipline         -1.683294\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "auto_woe_0.fit(train,\n",
    "               target_name=\"target\",\n",
    "              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.05961808, 0.59490484, 0.02947085, ..., 0.14355782, 0.06430498,\n",
       "       0.0480986 ])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_prediction = auto_woe_0.predict_proba(test)\n",
    "test_prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8016469307943301"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(test['target'].values, test_prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "report_params = {\"output_path\": \"HR_REPORT_1\", # folder for report generation\n",
    "                 \"report_name\": \"WHITEBOX REPORT\",\n",
    "                 \"report_version_id\": 1,\n",
    "                 \"city\": \"Moscow\",\n",
    "                 \"model_aim\": \"Predict if candidate will work for the company\",\n",
    "                 \"model_name\": \"HR model\",\n",
    "                 \"zakazchik\": \"Kaggle\",\n",
    "                 \"high_level_department\": \"Ai Lab\",\n",
    "                 \"ds_name\": \"Btbpanda\",\n",
    "                 \"target_descr\": \"Candidate will work for the company\",\n",
    "                 \"non_target_descr\": \"Candidate will work for the company\"}\n",
    "\n",
    "auto_woe_0.generate_report(report_params, )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AutoWoE - simpler model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "auto_woe_1 = AutoWoE(interpreted_model=True,\n",
    "                     monotonic=True,\n",
    "                     max_bin_count=4,\n",
    "                     select_type=None,\n",
    "                     pearson_th=0.9,\n",
    "                     metric_th=.505,\n",
    "                     vif_th=10.,\n",
    "                     imp_th=0,\n",
    "                     th_const=32,\n",
    "                     force_single_split=True,\n",
    "                     th_nan=0.01,\n",
    "                     th_cat=0.005,\n",
    "                     metric_tol=1e-4,\n",
    "                     cat_alpha=100,\n",
    "                     cat_merge_to=\"to_woe_0\",\n",
    "                     nan_merge_to=\"to_woe_0\",\n",
    "                     imp_type=\"feature_imp\",\n",
    "                     regularized_refit=False,\n",
    "                     p_val=0.05,\n",
    "                     verbose=2\n",
    "        )\n",
    "\n",
    "auto_woe_1 = ReportDeco(auto_woe_1, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Number of positive: 3033, number of negative: 9227\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000364 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 518\n",
      "[LightGBM] [Info] Number of data points in the train set: 12260, number of used features: 12\n",
      "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.247390 -> initscore=-1.112582\n",
      "[LightGBM] [Info] Start training from score -1.112582\n",
      "Training until validation scores don't improve for 10 rounds\n",
      "Early stopping, best iteration is:\n",
      "[11]\tval_test's auc: 0.810634\n",
      "city processing...\n",
      "city_development_index processing...\n",
      "gender processing...\n",
      "relevant_experience processing...\n",
      "enrolled_university processing...\n",
      "education_level processing...\n",
      "major_discipline processing...\n",
      "experience processing...\n",
      "company_size processing...\n",
      "company_type processing...\n",
      "last_new_job processing...\n",
      "training_hours processing...\n",
      "dict_keys(['city', 'city_development_index', 'gender', 'relevant_experience', 'enrolled_university', 'education_level', 'major_discipline', 'experience', 'company_size', 'company_type', 'last_new_job', 'training_hours']) to selector !!!!!\n",
      "Feature selection...\n",
      "city                     -0.525685\n",
      "city_development_index   -0.482931\n",
      "company_size             -0.884190\n",
      "company_type             -0.401782\n",
      "experience               -0.272925\n",
      "enrolled_university      -0.231768\n",
      "education_level          -0.673794\n",
      "major_discipline         -1.606442\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "auto_woe_1.fit(train,\n",
    "               target_name=\"target\",\n",
    "              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.06195668, 0.59982925, 0.03708212, ..., 0.13104366, 0.05378754,\n",
       "       0.0487648 ])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_prediction = auto_woe_1.predict_proba(test)\n",
    "test_prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7991679814815826"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(test['target'].values, test_prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "report_params = {\"output_path\": \"HR_REPORT_2\", # folder for report generation\n",
    "                 \"report_name\": \"WHITEBOX REPORT\",\n",
    "                 \"report_version_id\": 2,\n",
    "                 \"city\": \"Moscow\",\n",
    "                 \"model_aim\": \"Predict if candidate will work for the company\",\n",
    "                 \"model_name\": \"HR model\",\n",
    "                 \"zakazchik\": \"Kaggle\",\n",
    "                 \"high_level_department\": \"Ai Lab\",\n",
    "                 \"ds_name\": \"Btbpanda\",\n",
    "                 \"target_descr\": \"Candidate will work for the company\",\n",
    "                 \"non_target_descr\": \"Candidate will work for the company\"}\n",
    "\n",
    "auto_woe_1.generate_report(report_params, )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### WhiteBox preset - like TabularAutoML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightautoml.automl.presets.whitebox_presets import WhiteBoxPreset\n",
    "from lightautoml.tasks import Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = Task('binary')\n",
    "automl = WhiteBoxPreset(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Number of positive: 3033, number of negative: 9227\n",
      "[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000469 seconds.\n",
      "You can set `force_row_wise=true` to remove the overhead.\n",
      "And if memory is not enough, you can set `force_col_wise=true`.\n",
      "[LightGBM] [Info] Total Bins 518\n",
      "[LightGBM] [Info] Number of data points in the train set: 12260, number of used features: 12\n",
      "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.247390 -> initscore=-1.112582\n",
      "[LightGBM] [Info] Start training from score -1.112582\n",
      "Training until validation scores don't improve for 10 rounds\n",
      "Early stopping, best iteration is:\n",
      "[17]\tval_test's auc: 0.805941\n"
     ]
    }
   ],
   "source": [
    "\n",
    "train_pred = automl.fit_predict(train.reset_index(drop=True), roles={'target': 'target'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_prediction = automl.predict(test).data[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7966626448798652"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(test['target'].values, test_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Serialization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important note:** `auto_woe_1` is the `ReportDeco` object (the report generator object), not `AutoWoE` itself. To receive the `AutoWoE` object you can use the `auto_woe_1.model`.\n",
    "\n",
    "`ReportDeco` object usage for inference is **not** recommended for several reasons:\n",
    "- The report object needs to have the target column because of model quality metrics calculation\n",
    "- Model inference using `ReportDeco` object is slower than the usual one because of the report update procedure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the underlying AutoWoE model (not the ReportDeco wrapper), then load it back\n",
    "joblib.dump(auto_woe_1.model, 'model.pkl')\n",
    "model = joblib.load('model.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SQL inference query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT\n",
      "  1 / (1 + EXP(-(\n",
      "    -1.11\n",
      "    -0.526*WOE_TAB.city\n",
      "    -0.483*WOE_TAB.city_development_index\n",
      "    -0.884*WOE_TAB.company_size\n",
      "    -0.402*WOE_TAB.company_type\n",
      "    -0.273*WOE_TAB.experience\n",
      "    -0.232*WOE_TAB.enrolled_university\n",
      "    -0.674*WOE_TAB.education_level\n",
      "    -1.606*WOE_TAB.major_discipline\n",
      "  ))) as PROB,\n",
      "  WOE_TAB.*\n",
      "FROM \n",
      "    (SELECT\n",
      "    CASE\n",
      "      WHEN (city IS NULL OR LOWER(CAST(city AS VARCHAR(50))) = 'nan') THEN 0\n",
      "      WHEN city IN ('city_100', 'city_102', 'city_103', 'city_116', 'city_149', 'city_159', 'city_160', 'city_45', 'city_46', 'city_64', 'city_71', 'city_73', 'city_83', 'city_99') THEN 0.213\n",
      "      WHEN city IN ('city_104', 'city_114', 'city_136', 'city_138', 'city_16', 'city_173', 'city_23', 'city_28', 'city_36', 'city_50', 'city_57', 'city_61', 'city_65', 'city_67', 'city_75', 'city_97') THEN 1.017\n",
      "      WHEN city IN ('city_11', 'city_21', 'city_74') THEN -1.455\n",
      "      ELSE -0.209\n",
      "    END AS city,\n",
      "    CASE\n",
      "      WHEN (city_development_index IS NULL OR city_development_index = 'NaN') THEN 0\n",
      "      WHEN city_development_index <= 0.6245 THEN -1.454\n",
      "      WHEN city_development_index <= 0.7915 THEN -0.121\n",
      "      WHEN city_development_index <= 0.9235 THEN 0.461\n",
      "      ELSE 1.101\n",
      "    END AS city_development_index,\n",
      "    CASE\n",
      "      WHEN (company_size IS NULL OR company_size = 'NaN') THEN -0.717\n",
      "      WHEN company_size <= 74.0 THEN 0.221\n",
      "      ELSE 0.467\n",
      "    END AS company_size,\n",
      "    CASE\n",
      "      WHEN (company_type IS NULL OR LOWER(CAST(company_type AS VARCHAR(50))) = 'nan') THEN -0.64\n",
      "      WHEN company_type IN ('Early Stage Startup', 'NGO', 'Other', 'Public Sector') THEN 0.164\n",
      "      WHEN company_type == 'Funded Startup' THEN 0.737\n",
      "      WHEN company_type == 'Pvt Ltd' THEN 0.398\n",
      "      ELSE 0\n",
      "    END AS company_type,\n",
      "    CASE\n",
      "      WHEN (experience IS NULL OR experience = 'NaN') THEN 0\n",
      "      WHEN experience <= 1.5 THEN -0.811\n",
      "      WHEN experience <= 7.5 THEN -0.319\n",
      "      WHEN experience <= 11.5 THEN 0.119\n",
      "      ELSE 0.533\n",
      "    END AS experience,\n",
      "    CASE\n",
      "      WHEN (enrolled_university IS NULL OR LOWER(CAST(enrolled_university AS VARCHAR(50))) = 'nan') THEN -0.327\n",
      "      WHEN enrolled_university == 'Full time course' THEN -0.614\n",
      "      WHEN enrolled_university == 'Part time course' THEN 0.026\n",
      "      WHEN enrolled_university == 'no_enrollment' THEN 0.208\n",
      "      ELSE 0\n",
      "    END AS enrolled_university,\n",
      "    CASE\n",
      "      WHEN (education_level IS NULL OR LOWER(CAST(education_level AS VARCHAR(50))) = 'nan') THEN 0.21\n",
      "      WHEN education_level == 'Graduate' THEN -0.166\n",
      "      WHEN education_level == 'High School' THEN 0.34\n",
      "      WHEN education_level == 'Masters' THEN 0.21\n",
      "      WHEN education_level IN ('Phd', 'Primary School') THEN 0.704\n",
      "      ELSE 0\n",
      "    END AS education_level,\n",
      "    CASE\n",
      "      WHEN (major_discipline IS NULL OR LOWER(CAST(major_discipline AS VARCHAR(50))) = 'nan') THEN 0.333\n",
      "      WHEN major_discipline == 'Arts' THEN 0.199\n",
      "      WHEN major_discipline IN ('Business Degree', 'No Major', 'Other', 'STEM') THEN -0.071\n",
      "      WHEN major_discipline == 'Humanities' THEN 0.333\n",
      "      ELSE 0\n",
      "    END AS major_discipline\n",
      "  FROM global_temp.TABLE_1) as WOE_TAB\n"
     ]
    }
   ],
   "source": [
    "# Generate a SQL query that scores the rows of the given table with the fitted model\n",
    "sql_query = model.get_sql_inference_query('global_temp.TABLE_1')\n",
    "print(sql_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the SQL query with PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start a local Spark session (2 worker threads) to run the generated query\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .master(\"local[2]\")\n",
    "    .appName(\"spark-course\")\n",
    "    .config(\"spark.driver.memory\", \"512m\")\n",
    "    .getOrCreate()\n",
    ")\n",
    "sc = spark.sparkContext"
   ]
  },
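  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the dataset as a global temp view so the query can read it as global_temp.TABLE_1.\n",
    "# createOrReplaceGlobalTempView keeps the cell safe to re-run within the same Spark application.\n",
    "spark_df = spark.read.csv(DATASET_FULLNAME, header=True)\n",
    "spark_df.createOrReplaceGlobalTempView(\"TABLE_1\")"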
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = spark.sql(sql_query).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_prediction = model.predict_proba(data)\n",
    "full_prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest absolute gap between the SQL probabilities and the Python predictions\n",
    "(res['PROB'] - full_prediction).abs().max()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
